# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

BucketIter <- setRefClass("BucketIter", fields = c("buckets", "bucket.names", "batch.size", 
                                                   "data.mask.element", "shuffle", "bucket.plan", "bucketID", "epoch", "batch", "batch.per.bucket", 
                                                   "last.batch.pad", "batch.per.epoch", "seed"), 
                          methods = list(
                            initialize = function(buckets, 
                                                  batch.size, data.mask.element = 0, shuffle = FALSE, seed = 123) {
                              .self$buckets <- buckets
                              .self$bucket.names <- names(.self$buckets)
                              .self$batch.size <- batch.size
                              .self$data.mask.element <- data.mask.element
                              .self$epoch <- 0
                              .self$batch <- 0
                              .self$shuffle <- shuffle
                              .self$batch.per.bucket <- 0
                              .self$batch.per.epoch <- 0
                              .self$bucket.plan <- NULL
                              .self$bucketID <- NULL
                              .self$seed <- seed
                              .self
                            }, reset = function() {
                              buckets_nb <- length(bucket.names)
                              buckets_id <- seq_len(buckets_nb)
                              buckets.size <- sapply(.self$buckets, function(x) {
                                tail(dim(x$data), 1)
                              })
                              .self$batch.per.bucket <- ceiling(buckets.size/.self$batch.size)
                              .self$last.batch.pad <- .self$batch.size - buckets.size %% .self$batch.size
                              .self$last.batch.pad[.self$last.batch.pad == .self$batch.size] <- 0
                              
                              .self$batch.per.epoch <- sum(.self$batch.per.bucket)
                              # Number of batches per epoch given the batch.size
                              .self$batch.per.epoch <- sum(.self$batch.per.bucket)
                              .self$epoch <- .self$epoch + 1
                              .self$batch <- 0
                              
                              if (.self$shuffle) {
                                set.seed(.self$seed)
                                bucket_plan_names <- sample(rep(names(.self$batch.per.bucket), times = .self$batch.per.bucket))
                                .self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names, 
                                                         FUN = cumsum)
                                names(.self$bucket.plan) <- bucket_plan_names
                                # Return first BucketID at reset for initialization of the model
                                .self$bucketID <- .self$bucket.plan[1]
                                
                                .self$buckets <- lapply(.self$buckets, function(x) {
                                  shuffle_id <- sample(tail(dim(x$data), 1))
                                  if (length(dim(x$label)) == 0) {
                                    list(data = x$data[, shuffle_id], label = x$label[shuffle_id])
                                  } else {
                                    list(data = x$data[, shuffle_id], label = x$label[, shuffle_id])
                                  }
                                })
                              } else {
                                bucket_plan_names <- rep(names(.self$batch.per.bucket), times = .self$batch.per.bucket)
                                .self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names, 
                                                         FUN = cumsum)
                                names(.self$bucket.plan) <- bucket_plan_names
                              }
                            }, iter.next = function() {
                              .self$batch <- .self$batch + 1
                              .self$bucketID <- .self$bucket.plan[batch]
                              return(.self$batch <= .self$batch.per.epoch)
                            }, value = function() {
                              # bucketID is a named integer: the integer indicates the batch id for the given
                              # bucket (used to fetch appropriate samples within the bucket) the name is a
                              # character containing the sequence length of the bucket (used to unroll the rnn
                              # to appropriate sequence length)
                              idx <- (.self$bucketID - 1) * (.self$batch.size) + seq_len(batch.size)
                              
                              # Reuse first idx for padding
                              if (bucketID == .self$batch.per.bucket[names(.self$bucketID)] & !.self$last.batch.pad[names(.self$bucketID)] == 0) {
                                idx <- c(idx[seq_len(.self$batch.size - .self$last.batch.pad[names(.self$bucketID)])], seq_len(.self$last.batch.pad[names(.self$bucketID)]))
                              }
                              
                              data <- .self$buckets[[names(.self$bucketID)]]$data[, idx, drop = F]
                              seq.mask <- as.integer(names(bucketID)) - apply(data==.self$data.mask.element, 2, sum)
                              if (length(dim(.self$buckets[[names(.self$bucketID)]]$label)) == 0) {
                                label <- .self$buckets[[names(.self$bucketID)]]$label[idx]
                              } else {
                                label <- .self$buckets[[names(.self$bucketID)]]$label[, idx, drop = F]
                              }
                              return(list(data = mx.nd.array(data), seq.mask = mx.nd.array(seq.mask), 
                                          label = mx.nd.array(label)))
                            }, num.pad = function() {
                              if (bucketID == .self$batch.per.bucket[names(.self$bucketID)] & !.self$last.batch.pad[names(.self$bucketID)] == 0){
                                return(.self$last.batch.pad[names(.self$bucketID)])
                              } else return(0)
                            }, finalize = function() {
                            }))

# 
#' Create Bucket Iter
#'
#' @param buckets The data array.
#' @param batch.size The batch size used to pack the array.
#' @param data.mask.element The element to mask
#' @param shuffle Whether shuffle the data
#' @param seed The random seed
#'
#' @export
mx.io.bucket.iter <- function(buckets, batch.size, data.mask.element = 0, shuffle = FALSE, 
                              seed = 123) {
  return(BucketIter$new(buckets = buckets, batch.size = batch.size, data.mask.element = data.mask.element, 
                        shuffle = shuffle, seed = seed))
}